Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LlamaTokenizerFast] Refactor default llama #28881

Merged
merged 28 commits into from
Apr 23, 2024
Merged

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented Feb 6, 2024

What does this PR do?

from transformers import LlamaTokenizerFast, AddedToken
tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True)
tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
tokenizer.tokenize("<REPR_END>inform<s>. Hey.       .")
['<REPR_END>', 'in', 'form', '<s>', '.', '▁Hey', '.', '▁▁▁▁▁▁', '▁.']
tokenizer.tokenize("inform<s>. Hey.       .")
['in', 'form', '<s>', '.', '▁Hey', '.', '▁▁▁▁▁▁', '▁.']

This requires huggingface/tokenizers#1476 to be merged.
Finally FIXES ALL THE remaining tokenizer issues!

Fixes #29617 as well (will update Gemma one)
fixes #28577
fixes #29617
fixes #29626
fixes #29694
fixes #29868
fixes #29872
fixes #30416
enabled by huggingface/tokenizers#1476

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker marked this pull request as ready for review March 22, 2024 09:07
@huggingface huggingface deleted a comment from github-actions bot Apr 18, 2024
@ArthurZucker
Copy link
Collaborator Author

Coming next release!

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds great to me

src/transformers/convert_slow_tokenizer.py Show resolved Hide resolved
src/transformers/convert_slow_tokenizer.py Outdated Show resolved Hide resolved
@@ -1329,7 +1329,7 @@ def tokenizer(self, proto):
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)
user_defined_symbols = [
AddedToken(token, normalized=False, special=False) for token in proto.trainer_spec.user_defined_symbols
AddedToken(token, normalized=True, special=False) for token in proto.trainer_spec.user_defined_symbols
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a must? Unless with only the split this works.

@ArthurZucker ArthurZucker merged commit e34da3e into main Apr 23, 2024
22 checks passed
@ArthurZucker ArthurZucker deleted the refactor-default-llama branch April 23, 2024 21:13
itazap pushed a commit that referenced this pull request May 14, 2024
* push legacy to fast as well

* super strange

* Update src/transformers/convert_slow_tokenizer.py

* make sure we are BC

* fix Llama test

* nit

* revert

* more test

* style

* update

* small update w.r.t tokenizers

* nit

* don't split

* lol

* add a test for `add_prefix_space=False`

* fix gemma tokenizer as well

* update

* fix gemma

* nicer failures

* fixup

* update

* fix the example for legacy = False

* use `huggyllama/llama-7b` for the PR doctest

* nit

* use from_slow

* fix llama
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment